import copy as copy_lib
import functools
import gzip
import json
import numpy as np
import pickle
from scipy import sparse as sp
from typing import *

from . import residue_constants as rc
from .data_ops import NumpyDict


def lru_cache(maxsize=16, typed=False, copy=False, deepcopy=False):
    """Return an LRU-caching decorator that can optionally copy cached results.

    Args:
        maxsize: forwarded to ``functools.lru_cache``.
        typed: forwarded to ``functools.lru_cache``.
        copy: if true, each call returns ``copy.copy`` of the cached value.
        deepcopy: if true, each call returns ``copy.deepcopy`` of the cached
            value; takes precedence over ``copy``.

    With neither flag set this is exactly ``functools.lru_cache(maxsize, typed)``.
    Copying shields the cached object from mutation by callers. The copying
    wrapper is a plain function, so it does not expose ``cache_info`` /
    ``cache_clear``.
    """
    if deepcopy:
        copier = copy_lib.deepcopy
    elif copy:
        copier = copy_lib.copy
    else:
        return functools.lru_cache(maxsize, typed)

    def decorator(f):
        memoized = functools.lru_cache(maxsize, typed)(f)

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Hand out a fresh copy so the cached entry stays pristine.
            return copier(memoized(*args, **kwargs))

        return wrapper

    return decorator


@lru_cache(maxsize=8, deepcopy=True)
def load_pickle_safe(path: str) -> Dict[str, Any]:
    """Load a (optionally gzip-compressed) feature pickle, returning a deep copy.

    Results are memoized per ``path`` (up to 8 entries) by this module's
    ``lru_cache``. Every call returns a ``copy.deepcopy`` of the cached dict,
    so callers may freely mutate the result, including its arrays. Because the
    cache is keyed on the path string, later changes to the file on disk are
    not picked up.

    After loading, ``uncompress_features`` is applied to the object.

    Args:
        path: file path ending in ``.pkl`` or ``.pkl.gz`` (the latter is
            opened with ``gzip.open``).

    Raises:
        AssertionError: if ``path`` has neither suffix. It is raised
            explicitly, not via ``assert``, so the check survives ``python -O``.
    """
    if not path.endswith((".pkl", ".pkl.gz")):
        raise AssertionError(f"bad suffix in {path} as pickle file.")
    open_fn = gzip.open if path.endswith(".gz") else open
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # files from trusted sources.
    with open_fn(path, "rb") as f:
        ret = pickle.load(f)
    return uncompress_features(ret)


@lru_cache(maxsize=8, copy=True)
def load_pickle(path: str) -> Dict[str, Any]:
    """Load a (optionally gzip-compressed) feature pickle, returning a shallow copy.

    Results are memoized per ``path`` (up to 8 entries) by this module's
    ``lru_cache``. Every call returns a ``copy.copy`` of the cached dict:
    adding, removing or rebinding keys is safe, but the contained values
    (e.g. arrays) are shared with the cache. In-place modification of them
    therefore corrupts later calls. Use ``load_pickle_safe`` when that matters.
    Because the cache is keyed on the path string, later changes to the file
    on disk are not picked up.

    After loading, ``uncompress_features`` is applied to the object.

    Args:
        path: file path ending in ``.pkl`` or ``.pkl.gz`` (the latter is
            opened with ``gzip.open``).

    Raises:
        AssertionError: if ``path`` has neither suffix. It is raised
            explicitly, not via ``assert``, so the check survives ``python -O``.
    """
    if not path.endswith((".pkl", ".pkl.gz")):
        raise AssertionError(f"bad suffix in {path} as pickle file.")
    open_fn = gzip.open if path.endswith(".gz") else open
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # files from trusted sources.
    with open_fn(path, "rb") as f:
        ret = pickle.load(f)
    return uncompress_features(ret)


def correct_template_restypes(feature):
    """Re-index template residue types into the residue_constants ordering.

    Takes the argmax over the last axis of ``feature`` (presumably a one-hot /
    score encoding — confirm against callers). Each resulting index is then
    mapped through ``rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE``.
    """
    source_idx = np.argmax(feature, axis=-1).astype(np.int32)
    return np.take(rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, source_idx, axis=0)


def convert_all_seq_feature(feature: NumpyDict) -> NumpyDict:
    """Return a copy of ``feature`` re-keyed with an ``_all_seq`` suffix.

    Transformations applied to the result:
      * every key gets an ``_all_seq`` suffix unless it already ends with one;
      * ``msa`` is cast to ``uint8``;
      * ``num_alignments`` (if present) is dropped.

    The input dict is not modified. Values other than ``msa`` are shared with
    the input, not copied.

    Raises:
        KeyError: if ``feature`` has no ``msa`` entry.
    """
    # Cast first so a missing "msa" raises KeyError before anything is built.
    msa = feature["msa"].astype(np.uint8)

    def to_all_seq_key(key):
        return key if key.endswith("_all_seq") else f"{key}_all_seq"

    return {
        to_all_seq_key(k): (msa if k == "msa" else v)
        for k, v in feature.items()
        if k != "num_alignments"
    }


def to_dense_matrix(spmat_dict: NumpyDict):
    """Rebuild a dense ``float32`` array from a COO-style dict.

    ``spmat_dict`` must provide ``data``, ``row``, ``col`` and ``shape``. This
    is the layout that ``compress_features`` in this module stores under
    ``sparse_*`` keys.
    """
    values = spmat_dict["data"]
    coords = (spmat_dict["row"], spmat_dict["col"])
    coo = sp.coo_matrix((values, coords), shape=spmat_dict["shape"], dtype=np.float32)
    return coo.toarray()


# Per-feature dtype table (msa -> int32). Not referenced by the functions in
# this section; presumably consumed elsewhere — verify before changing.
FEATS_DTYPE = dict(msa=np.int32)


def uncompress_features(feats: NumpyDict) -> NumpyDict:
    """Densify a sparse deletion matrix in place and return the same dict.

    If ``feats`` holds ``sparse_deletion_matrix_int`` (a COO-style dict), that
    entry is removed. It is replaced by a dense array under the key
    ``deletion_matrix``, built via ``to_dense_matrix``. Otherwise ``feats`` is
    returned untouched.
    """
    sparse_key = "sparse_deletion_matrix_int"
    if sparse_key not in feats:
        return feats
    feats["deletion_matrix"] = to_dense_matrix(feats.pop(sparse_key))
    return feats


def filter(feature: NumpyDict, **kwargs) -> NumpyDict:
    """Select, validate, or drop feature keys according to exactly one keyword.

    Exactly one of these keyword arguments must be given:
      * ``desired_keys``: keep only keys in this collection (new dict);
      * ``required_keys``: check every listed key is present (returns
        ``feature`` itself, unchanged);
      * ``ignored_keys``: drop keys in this collection (new dict).

    The key collections may be any iterable, including one-shot generators.
    They are materialized once into a set.

    Raises:
        AssertionError: on wrong usage or a missing required key. These are
            raised explicitly, not via ``assert``, so the checks still run
            under ``python -O``.
    """
    # NOTE: this name shadows the builtin ``filter`` in this module; it is
    # kept for existing callers.
    if len(kwargs) != 1:
        raise AssertionError(f"wrong usage of filter with kwargs: {kwargs}")
    if "desired_keys" in kwargs:
        desired = set(kwargs["desired_keys"])
        return {k: v for k, v in feature.items() if k in desired}
    if "required_keys" in kwargs:
        for k in kwargs["required_keys"]:
            if k not in feature:
                raise AssertionError(f"cannot find required key {k}.")
        return feature
    if "ignored_keys" in kwargs:
        ignored = set(kwargs["ignored_keys"])
        return {k: v for k, v in feature.items() if k not in ignored}
    raise AssertionError(f"wrong usage of filter with kwargs: {kwargs}")


def compress_features(features: NumpyDict):
    """Shrink a feature dict for storage.

    The following transformations are applied:
      * ``msa`` is cast to ``uint8``;
      * ``deletion_matrix_int`` is converted to a COO sparse matrix and stored
        as a plain dict with ``shape``/``row``/``col``/``data`` under the key
        ``sparse_deletion_matrix_int``;
      * every other entry is passed through as-is.

    Returns a new dict in the input's key order; ``features`` is not modified.
    ``uncompress_features`` in this module densifies the sparse entry back,
    under the key ``deletion_matrix``.
    """
    dtype_overrides = {"msa": np.uint8}
    sparse_names = ("deletion_matrix_int",)

    def encode(name, value):
        target = dtype_overrides.get(name)
        if target is not None:
            value = value.astype(target)
        if name not in sparse_names:
            return name, value
        coo = sp.coo_matrix(value, dtype=value.dtype)
        packed = {"shape": coo.shape, "row": coo.row, "col": coo.col, "data": coo.data}
        return f"sparse_{name}", packed

    return dict(encode(name, value) for name, value in features.items())
